library(cregg)
library(ggplot2)
library(ggpubr)
# The workspace is intentionally not wiped with rm(list = ls()); run this
# script in a fresh R session instead.

#### Load data

# `load()` restores the saved objects into the calling (here: global)
# environment and returns their names. The rest of this script needs an object
# called `data`; check for it here, because otherwise `data` silently resolves
# to utils::data() and the first subset fails with a cryptic
# "object of type 'closure' is not subsettable" error.
loaded_objects <- load("data.RData")
stopifnot("data" %in% loaded_objects)

#### Formulas for models

## Model formula: `vote_choice` on the eight conjoint features. The code that
## builds the Figure 1 tables indexes the cregg output by position and assumes
## this feature order (age, prof, career, gender, legeff, voteror, voterknow,
## smedia).
f1 <- reformulate(
  termlabels = c("age", "prof", "career", "gender",
                 "legeff", "voteror", "voterknow", "smedia"),
  response = "vote_choice"
)


#### Compute marginal means and AMCEs
# Respondents from these countries are left out of the pooled analysis.
excluded_countries <- c("Germany", "Switzerland")
# Drop missing countries explicitly: `country != ...` is NA for them, and an
# NA logical index makes `[` emit an all-NA row instead of dropping the row.
analysis_data <- data[!is.na(data$country) &
                        !(data$country %in% excluded_countries), ]

# NOTE(review): confirm that cregg::mm() actually uses `family`; if its
# marginal means are plain (weighted) means, this argument has no effect.
marginalmeans <- mm(analysis_data, f1, id = ~id,
                    family = binomial(link = "logit"))

amces <- cj(analysis_data, f1, id = ~id)


#### Create data for Figure 1


# Build one Figure 1 plotting table from a cregg estimate table (`mm()` or
# `cj()` output).
#
# @param est  cregg result with columns `level`, `estimate`, `lower`, `upper`.
# @param type Facet title stored in the `type` column.
# @return A data.frame with columns outcome, feature, level, estimate, lower,
#   upper, type. Each feature contributes a header row (label only, NA
#   estimate/CI) followed by one row per level. `level` and `feature` are
#   factors ordered for display.
#
# NOTE(review): rows of `est` are taken by position. This assumes cregg returns
# the features in the order of `f1` with 5, 5, 2, 2, 2, 2, 2, 2 levels. It also
# assumes each feature's levels come in the same order as the hard-coded labels
# below. Confirm against levels(data$<feature>).
build_figure1_data <- function(est, type) {
  n_levels <- c(age = 5, prof = 5, career = 2, gender = 2,
                legeff = 2, voteror = 2, voterknow = 2, smedia = 2)
  # Fail loudly instead of silently mislabelling rows if the table changes.
  stopifnot(nrow(est) == sum(n_levels))

  headers <- c(age = "Age",
               prof = "Profession",
               career = "Career",
               gender = "Gender",
               legeff = "Legislative Effectiveness",
               voteror = "Presence in Constituency",
               voterknow = "Knowledge of voters",
               smedia = "Social Media")
  level_labels <- list(
    # Age keeps cregg's own level labels.
    age = as.character(est$level[1:5]),
    prof = c("Business Person", "Farmer", "Lawyer", "Medical doctor",
             "Parliamentary assistant"),
    career = c("Joined party 3 years ago", "Party offices in last 15 years"),
    gender = c("Man", "Woman"),
    legeff = c("Most projects accepted", "Some projects accepted"),
    voteror = c("Little with voters in constituency",
                "Lot of time with voters in constituency"),
    voterknow = c("Low knowledge of voters' position",
                  "Good knowledge of voters' position"),
    smedia = c("Regularly communicate views", "Do not communicate views")
  )
  # data.frame() would silently recycle a label vector of the wrong length.
  stopifnot(all(lengths(level_labels)[names(n_levels)] == n_levels))

  last_row <- cumsum(n_levels)
  first_row <- last_row - n_levels + 1

  pieces <- lapply(names(n_levels), function(feature) {
    rows <- seq(first_row[[feature]], last_row[[feature]])
    data.frame(
      outcome  = "vote_choice",
      feature  = feature,
      level    = c(headers[[feature]], level_labels[[feature]]),
      estimate = c(NA, est$estimate[rows]),
      lower    = c(NA, est$lower[rows]),
      upper    = c(NA, est$upper[rows])
    )
  })
  out <- do.call(rbind, pieces)

  # Display order of the feature groups; the plot reverses the y axis
  # (`limits = rev`), so the first group here ends up at the top.
  display_order <- c("legeff", "voteror", "voterknow", "smedia",
                     "age", "gender", "career", "prof")
  level_order <- unlist(
    lapply(display_order, function(f) out$level[out$feature == f]),
    use.names = FALSE
  )
  out$level <- factor(out$level, levels = level_order)
  out$feature <- factor(out$feature, levels = display_order)
  out$type <- type
  out
}

dataPlotMarginalmeans <- build_figure1_data(marginalmeans, "Marginal Means")
dataPlotAmces <- build_figure1_data(amces, "Average Marginal Component Effect")


# Figure 1: AMCEs and marginal means with confidence intervals, one facet per
# estimate type and one colour per feature.
figure1_data <- rbind(dataPlotAmces, dataPlotMarginalmeans)

# Reference line per facet (0 for AMCEs, 0.5 for marginal means). It is keyed on
# `type`, so each line lands in its own facet regardless of row order.
reference_lines <- data.frame(
  type = c(unique(dataPlotAmces$type), unique(dataPlotMarginalmeans$type)),
  xintercept = c(0, 0.5)
)

# The first row of each feature is its header label, which is shown in bold.
# Discrete y labels are drawn bottom-to-top, which with `limits = rev` means
# reverse level order.
header_levels <- as.character(
  dataPlotMarginalmeans$level[!duplicated(dataPlotMarginalmeans$feature)]
)
axis_faces <- ifelse(rev(levels(figure1_data$level)) %in% header_levels,
                     "bold", "plain")

dir.create("Figures", showWarnings = FALSE)
png("Figures/Figure 1.png", width = 3000, height = 3000, res = 300)
figure1 <- ggplot(figure1_data) +
  geom_point(aes(x = estimate, y = level, color = feature)) +
  geom_errorbarh(aes(y = level, xmin = lower, xmax = upper, color = feature),
                 linewidth = 0.5, height = 0.5) +
  # `type` already holds the strip titles, so no labeller is needed.
  facet_wrap(~type, scales = "free_x") +
  ylab("") +
  scale_color_manual(
    # Named so the colour mapping does not depend on factor-level order.
    values = c(legeff = "blue", voteror = "green", voterknow = "black",
               smedia = "grey", age = "orange", gender = "red",
               career = "dark green", prof = "brown"),
    labels = c(age = "Age", prof = "Profession",
               career = "Career", gender = "Gender",
               legeff = "Legislative Effectiveness",
               voteror = "Presence in constituency",
               voterknow = "Knowledge of Voters", smedia = "Social Media")
  ) +
  geom_vline(data = reference_lines, aes(xintercept = xintercept),
             linewidth = 0.1) +
  theme_minimal() +
  scale_y_discrete(limits = rev) +
  theme(legend.position = "bottom", legend.title = element_blank(),
        panel.border = element_rect(colour = "black", fill = NA,
                                    linewidth = 0.5),
        # NOTE(review): recent ggplot2 versions may warn that vectorised
        # element_text() input is not officially supported.
        axis.text.y = element_text(face = axis_faces))
# Print explicitly: a top-level ggplot is not auto-printed under source().
print(figure1)
dev.off()
                                       